#!/usr/bin/env python3

# Copyright Jamie Iles, 2017
#
# This file is part of s80x86.
#
# s80x86 is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# s80x86 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with s80x86.  If not, see <http://www.gnu.org/licenses/>.

import os
import pprint
import pystache
import struct
import yaml

from py8086sim.Cpu import RTLCPU, GPR, Flag

# Inputs live in the 'documentation' directory next to this script's
# parent directory.
HERE = os.path.dirname(__file__)
_DOC_DIR = os.path.join(HERE, '..', 'documentation')
INSTR_TEMPLATE = os.path.join(_DOC_DIR, 'instructions.templ')

# Override the pystache defaults: tags are written as <% ... %> and the tag
# escape function is the identity, so substituted values pass through as-is.
pystache.defaults.DELIMITERS = (u'<%', u'%>')
pystache.defaults.TAG_ESCAPE = lambda u: u

# Instruction definitions, keyed by mnemonic (see the generation loop below).
with open(os.path.join(_DOC_DIR, 'instructions.yaml')) as instr:
    instructions = yaml.load(instr, Loader=yaml.FullLoader)

def has_modrm(operands):
    """Return True if any operand code starts with 'E', ignoring 'ES'.

    The literal 'ES' is excluded because it is a register name rather than
    an 'E' operand code (presumably the ES segment register).
    """
    for operand in operands:
        if operand.startswith('E') and operand != 'ES':
            return True
    return False

def has_immediate(operands):
    """Return True if any operand code starts with 'I' (an immediate)."""
    return any(operand.startswith('I') for operand in operands)

def has_mem_only(operands):
    """Return True if any operand code starts with 'M' (memory-only)."""
    return any(operand.startswith('M') for operand in operands)

def to_raw_reg(gpr):
    """Return the register number to place in a ModR/M field.

    Values at or above GPR.AL are rebased so that GPR.AL maps to 0; any
    other value is returned as its plain integer value.
    """
    number = int(gpr)
    return number - int(GPR.AL) if gpr >= GPR.AL else number

def modrm(reg, encoding, cpu, timing_operands, rm_reg=None):
    """Build the ModR/M byte for an encoding and preload its operand values.

    Args:
        reg: default value for the reg field; overridden by encoding['reg']
            when the encoding specifies one.
        encoding: encoding definition dict; 'reg' and 'operands' are read.
        cpu: simulated CPU whose registers/memory receive the operand values.
        timing_operands: per-operand values (indexed like the operand list)
            to load before timing, or None/empty to load nothing.
        rm_reg: register to use as the r/m operand.  When None the r/m
            operand is the memory location addressed through BX.

    Returns:
        The ModR/M byte: mod=3 with r/m taken from rm_reg when a register is
        given, otherwise mod=0 with r/m=7 (the [BX] form of the 16-bit
        ModR/M encoding).
    """
    reg = int(encoding.get('reg', reg))
    operands = encoding.get('operands', [])

    # Test against None explicitly: a register whose enum value is 0 is
    # falsy and would otherwise be mistaken for "no register given".
    use_rm_reg = rm_reg is not None

    if timing_operands:
        for i, o in enumerate(operands):
            if o.startswith('E'):
                if use_rm_reg:
                    cpu.write_reg(GPR.values[rm_reg], timing_operands[i])
                    continue

                # Memory operand: point BX at offset 0x100 and store the
                # value there (in the DS segment) with the operand's width.
                cpu.write_reg(GPR.BX, 0x100)
                if o.endswith('b'):
                    cpu.write_mem8(cpu.read_reg(GPR.DS), 0x100, timing_operands[i])
                elif o.endswith('w'):
                    cpu.write_mem16(cpu.read_reg(GPR.DS), 0x100, timing_operands[i])
                else:
                    cpu.write_mem32(cpu.read_reg(GPR.DS), 0x100, timing_operands[i])
            elif o.startswith('G'):
                cpu.write_reg(GPR.values[reg], timing_operands[i])

    if use_rm_reg:
        return (3 << 6) | (to_raw_reg(reg) << 3) | to_raw_reg(rm_reg)

    return to_raw_reg(reg) << 3 | 7

def immediate(operands, timing_operands):
    """Return the immediate/displacement bytes that follow the opcode.

    Args:
        operands: operand codes of the encoding.  Only codes starting with
            'I' (immediate) or 'J' (displacement) contribute bytes.
        timing_operands: per-operand values, indexed like ``operands``, or
            None/empty to encode every immediate as zero.

    Returns:
        A list of ints, little-endian, in operand order.  A code ending in
        'b' produces 1 byte, one ending in 'w' produces 2 bytes and anything
        else produces 4 bytes.

    Raises:
        struct.error: if a value does not fit the operand's unsigned width.
    """
    imm_bytes = []
    for i, o in enumerate(operands):
        # startswith() rather than o[0] so that an empty operand string is
        # skipped instead of raising IndexError.
        if not o.startswith(('I', 'J')):
            continue

        v = int(timing_operands[i]) if timing_operands else 0

        if o.endswith('b'):
            struct_fmt = '<B'
        elif o.endswith('w'):
            struct_fmt = '<H'
        else:
            struct_fmt = '<I'

        # Iterating bytes yields ints, so the packed value extends directly.
        imm_bytes.extend(struct.pack(struct_fmt, v))
    return imm_bytes

# Text contributed by each character of an operand code.  'E' is absent
# because its text is chosen per call by describe_operands().
_OPERAND_CODE_TEXT = {
    'G': 'register',
    'M': 'memory',
    'I': 'immediate',
    'J': 'displacement',
    'O': 'moffset',
    'w': '16',
    'b': '8',
    'p': '(seg:offset)',
}

def describe_operands(operands, timing_operands, E='memory'):
    """Turn operand codes into readable names plus an optional example.

    Args:
        operands: operand codes; an entry that is a key of GPR.names is
            used verbatim, anything else is expanded character by character.
        timing_operands: values used for timing, shown as a hex example, or
            None/empty for no example.
        E: text substituted for an 'E' character in an operand code.

    Returns:
        A (names, example) tuple where example is a string of the form
        '(ex: 0x.. 0x..)' or None when there are no timing operands.

    Raises:
        ValueError: if an operand code contains an unknown character.
    """
    code_text = dict(_OPERAND_CODE_TEXT, E=E)
    names = []

    for operand in operands:
        if operand in GPR.names:
            names.append(operand)
            continue

        name = ''
        for code in operand:
            if code not in code_text:
                raise ValueError("Invalid operand type %s %c" % (operand, code))
            name += code_text[code]
        names.append(name)

    if not timing_operands:
        return names, None

    hex_values = ' '.join('0x{0:x}'.format(t) for t in timing_operands)
    return names, '(ex: ' + hex_values + ')'

def encode(encoding, cpu):
    """Generate every instruction form to be timed for one encoding.

    The CPU is reset once, then for each entry of
    encoding['timing_operands'] (a single None entry when absent) one or two
    forms are produced:

    * memory-only operands ('M'): one form with a memory ModR/M byte;
    * ModR/M operands ('E'): a register form (BL when 'Eb' is among the
      operands, BX otherwise) followed by a memory form;
    * anything else: the opcode followed by its immediates.

    Yields:
        (description, instruction_bytes) pairs, where description is the
        tuple returned by describe_operands().  The ModR/M byte, and the CPU
        state it sets up, is computed lazily just before each yield.
    """
    cpu.reset()

    opcode_bytes = [encoding['opcode']]
    # A missing 'operands' key gives an empty list, for which both
    # has_mem_only() and has_modrm() are False.
    operands = encoding.get('operands', [])

    for timing in encoding.get('timing_operands', [None]):
        immed = immediate(operands, timing)

        if has_mem_only(operands):
            description = describe_operands(operands, timing)
            modrm_byte = modrm(GPR.AX, encoding, cpu, timing)
            yield description, opcode_bytes + [modrm_byte] + immed
            continue

        if not has_modrm(operands):
            yield describe_operands(operands, timing), opcode_bytes + immed
            continue

        gpr = GPR.BL if 'Eb' in operands else GPR.BX

        description = describe_operands(operands, timing, E='register')
        modrm_byte = modrm(GPR.AX, encoding, cpu, timing, rm_reg=gpr)
        yield description, opcode_bytes + [modrm_byte] + immed

        description = describe_operands(operands, timing, E='memory')
        modrm_byte = modrm(GPR.AX, encoding, cpu, timing)
        yield description, opcode_bytes + [modrm_byte] + immed

def count_cycles(cpu, instr, note=None, skip=False):
    """Time one instruction on the simulated CPU.

    Args:
        cpu: simulated CPU to run the instruction on.
        instr: instruction bytes, written at the current CS:IP.
        note: optional text appended to the count.
        skip: when true, nothing is executed and '-' is returned.

    Returns:
        '-' when skipped; the value of cpu.time_step() when there is no
        note; otherwise the string '<count> <note>'.
    """
    if skip:
        return '-'

    cpu.write_vector8(cpu.read_reg(GPR.CS), cpu.read_reg(GPR.IP), instr)
    # CS is rewritten with its own value and the CPU idled before timing --
    # presumably so the freshly written bytes are fetched; confirm against
    # the simulator.
    cpu.write_reg(GPR.CS, cpu.read_reg(GPR.CS))
    cpu.idle(256)

    cycles = cpu.time_step()
    return '{0} {1}'.format(cycles, note) if note else cycles

def describe_modrm(encoding):
    """Return the ModR/M suffix shown after the opcode.

    ' /<reg>' when the encoding fixes the reg field, ' /r' when it has a
    ModR/M operand, and an empty string otherwise.
    """
    if 'reg' not in encoding:
        return ' /r' if has_modrm(encoding.get('operands', [])) else ''
    return ' /{0}'.format(encoding['reg'])

def describe_immediates(encoding):
    """Return a space followed by the comma-separated immediate operands.

    Only operand codes whose first character is 'I' are listed; with none,
    the result is a single space.
    """
    immediates = [op for op in encoding.get('operands', []) if op[0] == 'I']
    return ' ' + ', '.join(immediates)

def setup(cpu, params):
    flags = 0

    for key, val in params.items():
        if key in GPR.names:
            cpu.write_reg(GPR.names[key], int(val))
        elif key in Flag.names and int(val) != 0:
            flags |= Flag.names[key]

    cpu.write_flags(flags)

# Build one template entry per mnemonic.  Every encoding is assembled and
# timed on a fresh simulated CPU.
definitions = []
for mnemonic in sorted(instructions):
    definition = instructions[mnemonic]
    encodings = []

    for encoding in definition['encodings']:
        cpu = RTLCPU('instructions')
        cycle_note = encoding.get('note', '')

        for operand_description, instr in encode(encoding, cpu):
            opcode = '{0:02x}'.format(encoding['opcode'])
            opcode += describe_modrm(encoding)
            opcode += describe_immediates(encoding)

            size = len(instr)
            # Optional immediate added to form EA
            if has_modrm(encoding.get('operands', [])):
                size = '{0}-{1}'.format(size, size + 2)

            prefix = []
            if 'use_prefix' in encoding:
                prefix = [int(encoding['use_prefix'])]
            if 'taken' in encoding:
                # Conditional instructions: time both outcomes, each from
                # its own register/flag setup.
                setup(cpu, encoding['taken'])
                taken = count_cycles(cpu, instr, note=cycle_note)
                setup(cpu, encoding['not_taken'])
                not_taken = count_cycles(cpu, instr, note=cycle_note)
                cycles = '{0} (taken) / {1} (not taken)'.format(taken, not_taken)
            elif 'base' in encoding:
                # Prefixed instructions: time the 'base' and 'next' setups
                # and report the difference as the per-iteration cost.
                # count_cycles() returns a string when given a note, which
                # cannot be subtracted, so take the raw counts here and
                # append the note afterwards.
                setup(cpu, encoding['base'])
                base = count_cycles(cpu, prefix + instr)
                setup(cpu, encoding['next'])
                n = count_cycles(cpu, prefix + instr)
                cycles = '{0} + (n-1)*{1}'.format(base, n - base)
                if cycle_note:
                    cycles = '{0} {1}'.format(cycles, cycle_note)
            else:
                cycles = count_cycles(cpu, instr,
                                      note=cycle_note,
                                      skip=encoding.get('notime', False))

            desc = mnemonic + ' ' + ', '.join(operand_description[0])
            if operand_description[1]:
                desc += ' ' + operand_description[1]
            encodings.append({
                'opcode': opcode,
                'mnemonic': desc,
                'cycles': cycles,
                'size': size
            })

        # Verilator uses a singleton for some things, make sure to gc the
        # cpu before creating a new one
        cpu = None

    definitions.append({
        'mnemonic': mnemonic.upper(),
        'encodings': encodings,
        'flags': definition.get('flags', {}),
        # The template receives each note as its own {'note': ...} mapping.
        'notes': [{'note': text} for text in definition.get('notes', [])],
    })

# Render the collected definitions through the template and print the
# result to stdout.
with open(INSTR_TEMPLATE) as template:
    template_text = template.read()
print(pystache.render(template_text, {'instructions': definitions}))
